library(tidyverse)
library(ComplexHeatmap)
library(circlize)
library(pheatmap)
library(corrplot)
library(ggplot2)
library(reshape2)
library(grid)
library(dplyr)


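# Before running this script, load the input tables into the session: it expects
# `df_filtered_list` (one table per heatmap) and `stageB` (one heatmap name per
# table) to already exist.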

#------------------- start -----------------------

# Display order of the 40 Bf clusters (heatmap columns), shared by every
# heatmap in the stacked list.
col_order <- c(
  "Bf_21", "Bf_31", "Bf_33", "Bf_16", "Bf_30", "Bf_35",
  "Bf_6",
  "Bf_17",
  "Bf_28", "Bf_15",
  "Bf_18",
  "Bf_34",
  "Bf_13", "Bf_3", "Bf_8",
  "Bf_39", "Bf_36", "Bf_9", "Bf_22", "Bf_7", "Bf_1",
  "Bf_26", "Bf_24", "Bf_25",
  "Bf_5",  "Bf_37", "Bf_20", "Bf_27", "Bf_10",
  "Bf_14", "Bf_32",
  "Bf_29",
  "Bf_11", "Bf_38", "Bf_0", "Bf_23", "Bf_2", "Bf_19", "Bf_4", "Bf_12"
)

# Neuron / Non-neuron label per column for the top annotation bar.
# Every position is "Neuron" except the positions listed in `non_neuron_pos`
# (same labels as the original hand-written 40-element vector).
# NOTE(review): ComplexHeatmap pairs annotation values with matrix columns by
# position in the matrix's own column order (before `column_order` is
# applied). Confirm these positions refer to colnames(df), not to col_order.
non_neuron_pos <- c(1, 4, 5, 12, 13, 17, 33, 35)

stopifnot(
  !anyDuplicated(col_order),
  all(non_neuron_pos >= 1 & non_neuron_pos <= length(col_order))
)

column_groups <- replace(
  rep("Neuron", length(col_order)),
  non_neuron_pos,
  "Non-neuron"
)


# Column annotation drawn above the first heatmap: a thin colour bar labelling
# each column with its `column_groups` value (orchid = Neuron,
# light green = Non-neuron).
top_anno <- HeatmapAnnotation(
  type = column_groups,
  col = list(
    type = c(
      "Neuron"     = "orchid",
      "Non-neuron" = "lightgreen"
    )
  ),
  simple_anno_size   = unit(2.5, "mm"),
  annotation_name_gp = gpar(fontsize = 8)
)

# Build one heatmap per table in `df_filtered_list` and stack them vertically.
# Colours span the 0-1 range with the "Oslo" palette. Cells whose value is
# above `dot_threshold` also get a yellow dot. Only the first (top) heatmap
# carries the Neuron / Non-neuron annotation bar.
# Assumes every table has the Bf_* columns named in `col_order`.
dot_threshold <- 0.2

heatmap_list <- lapply(seq_along(df_filtered_list), function(i) {
  df <- df_filtered_list[[i]]
  mat <- as.matrix(df)
  Heatmap(mat,
          top_annotation = if (i == 1) top_anno else NULL,
          col = colorRamp2(seq(0, 1, length.out = 25), hcl.colors(25, "Oslo")),
          name = stageB[i],
          column_names_rot = -90,
          column_order = col_order,
          row_names_side = "right",
          # Keep rows in the table's own order.
          row_order = rownames(df),
          show_row_dend = TRUE,
          # Fixed cell size so the final PDF can be sized from the layout.
          width = ncol(df) * unit(2.8, "mm"),
          height = nrow(df) * unit(2.8, "mm"),
          rect_gp = gpar(col = "#ccc", lwd = 1),
          row_names_gp = gpar(fontsize = 8),
          column_names_gp = gpar(fontsize = 8),
          show_heatmap_legend = TRUE,
          cell_fun = function(j, i, x, y, width, height, fill) {
            # `i`/`j` here are the cell's row/column in `mat` (this `i`
            # shadows the outer list index). isTRUE() skips NA cells instead
            # of erroring inside if().
            if (isTRUE(mat[i, j] > dot_threshold)) {
              grid.points(x, y, pch = 16, size = unit(1, "mm"),
                          gp = gpar(col = "#ff3"))
            }
          },
          border = TRUE)
})

# Stack the heatmaps top-to-bottom and draw a preview on the current device.
ht_list <- Reduce(`%v%`, heatmap_list)
draw(ht_list)


# Export the stacked heatmaps to a PDF sized to the drawn layout.
# Draw once with the same legend settings as the PDF call below, so the
# measured width/height match what ends up in the file.
ht_drawn <- draw(ht_list,
                 show_heatmap_legend = FALSE,
                 show_annotation_legend = FALSE)

# `:::` reaches ComplexHeatmap's width()/height() for the drawn list.
w <- ComplexHeatmap:::width(ht_drawn)
h <- ComplexHeatmap:::height(ht_drawn)

w_inch <- convertWidth(w, "inch", valueOnly = TRUE)
h_inch <- convertHeight(h, "inch", valueOnly = TRUE)

out_pdf <- "./Mouse_Class vs Bf snRNAseq-1.pdf"
pdf(out_pdf, width = w_inch, height = h_inch)
# `finally` closes the PDF device even if draw() fails, so later plots do not
# silently land in this file. invisible() keeps the returned HeatmapList from
# being auto-printed at top level.
invisible(tryCatch(
  draw(ht_list,
       show_heatmap_legend = FALSE,
       show_annotation_legend = FALSE),
  finally = dev.off()
))